import logging
from typing import Dict, Sequence

from malib.utils.typing import PolicyID, AgentID, BehaviorMode
from open_spiel.python.policy import (
    Policy as OSPolicy,
    TabularPolicy,
)
from open_spiel.python.algorithms.exploitability import nash_conv

try:
    import pyspiel
except ImportError as e:
    logging.warning(
        "Cannot import open spiel, if you wanna run meta game experiment, please install it before that."
    )

import numpy as np
from malib.algorithm.common.policy import Policy


# inference: https://github.com/JBLanier/pipeline-psro


def build_open_spiel_policy(malib_policy: Policy, open_spiel_game):
    """Wrap a malib policy so it satisfies the open_spiel policy API.

    Returns a ``PolicyFromCallable`` whose callable maps a ``pyspiel.State``
    to a ``{legal_action: probability}`` dict computed by ``malib_policy``.
    """

    def policy_callable(state: pyspiel.State):
        action_mask = state.legal_actions_mask()
        legal_actions = state.legal_actions()
        info_state = state.information_state_tensor()

        transformed_obs = malib_policy.preprocessor.transform(
            {
                "observation": np.asarray(info_state),
                "action_mask": np.asarray(action_mask),
            }
        )

        # FIXME(ming): check action mask available
        # malib_policy.eval()
        compute_kwargs = {
            "action_mask": np.asarray([action_mask], dtype=np.float32),
            "behavior_mode": BehaviorMode.EXPLOITATION,
            "legal_actions_list": legal_actions,
        }
        _, action_probs, _ = malib_policy.compute_action(
            observation=[transformed_obs], **compute_kwargs
        )

        # Keep only the probabilities of actions the mask marks legal; the
        # mask order matches ``legal_actions`` order in open_spiel.
        probs_of_legal = [
            action_probs[0][idx]
            for idx, masked in enumerate(action_mask)
            if masked == 1.0
        ]
        return dict(zip(legal_actions, probs_of_legal))

    return PolicyFromCallable(game=open_spiel_game, callable_policy=policy_callable)


def tabular_policy_from_weighted_policies(game, policy_iterable, weights):
    """Collapse a weighted mixture of policies into a single TabularPolicy.

    ``policy_iterable`` and ``weights`` are consumed in lockstep; the weights
    must sum to 1 so the resulting per-state action distributions are valid.
    """
    assert np.isclose(1.0, sum(weights))

    merged = TabularPolicy(game)
    merged.action_probability_array = np.zeros_like(merged.action_probability_array)

    # Accumulate weight * policy probability for every (state, action) cell.
    for policy, weight in zip(policy_iterable, weights):
        for row, state in enumerate(merged.states):
            accumulated = merged.action_probabilities(state)
            contribution = policy.action_probabilities(state)
            merged.action_probability_array[row, :] = [
                accumulated.get(action, 0.0) + contribution.get(action, 0.0) * weight
                for action in range(game.num_distinct_actions())
            ]

    # check that all action probs pers state add up to one in the newly created policy
    for row, state in enumerate(merged.states):
        probabilities = merged.action_probabilities(state)
        per_action = [
            probabilities.get(action, 0.0)
            for action in range(game.num_distinct_actions())
        ]
        assert np.isclose(1.0, sum(per_action)), "INFOSTATE POLICY: {}".format(
            per_action
        )

    return merged


def measure_exploitability(
    game_name: str,
    populations: Dict[AgentID, Dict[PolicyID, Policy]],
    # NOTE(review): the outer key is the same AgentID used to index
    # ``populations`` below, so it is annotated as AgentID (was PolicyID).
    policy_mixture_dict: Dict[AgentID, Dict[PolicyID, float]],
):
    """Exploitability calculation for sequential games.

    :param game_name: game identifier passed to ``pyspiel.load_game``.
    :param populations: per-agent pool of malib policies, keyed by policy id.
    :param policy_mixture_dict: per-agent mixture weights over the policy ids
        in ``populations`` (each agent's weights should sum to 1).
    :return: result of open_spiel's ``nash_conv`` with
        ``return_only_nash_conv=False`` (includes per-player values, not just
        the scalar).
    """

    open_spiel_game = pyspiel.load_game(game_name)

    def policy_iterable(mpid):
        # Yields open_spiel-wrapped policies in the iteration order of the
        # mixture dict so they pair correctly with ``.values()`` below
        # (dict iteration order is stable in Python 3.7+).
        for pid in policy_mixture_dict[mpid]:
            single_open_spiel_policy = build_open_spiel_policy(
                populations[mpid][pid], open_spiel_game
            )
            yield single_open_spiel_policy

    # One weighted tabular policy per agent, in ``populations`` key order.
    policies = [
        tabular_policy_from_weighted_policies(
            open_spiel_game, policy_iterable(aid), policy_mixture_dict[aid].values()
        )
        for aid in populations
    ]
    open_spiel_policy = NFSPPolicies(open_spiel_game, policies)

    # return exploitability(game=open_spiel_game, policy=open_spiel_policy)
    return nash_conv(
        game=open_spiel_game, policy=open_spiel_policy, return_only_nash_conv=False
    )


class PolicyFromCallable(OSPolicy):
    """For backwards-compatibility reasons, create a policy from a callable."""

    def __init__(self, game, callable_policy):
        # When creating a Policy from a pyspiel_policy, we do not have the game.
        all_players = None if game is None else list(range(game.num_players()))
        super().__init__(game, all_players)
        self._callable_policy = callable_policy

    def action_probabilities(self, state, player_id=None):
        # Delegate to the wrapped callable; coerce its mapping into a dict.
        return dict(self._callable_policy(state))


class NFSPPolicies(OSPolicy):
    """Joint policy to be evaluated."""

    def __init__(self, game, nfsp_policies: Sequence[OSPolicy]):
        # Two-player joint policy: one sub-policy per player id.
        super().__init__(game, [0, 1])
        self._policies = nfsp_policies
        self._obs = {"info_state": [None, None], "legal_actions": [None, None]}

    def action_probabilities(self, state, player_id=None):
        # Route the query to whichever sub-policy owns the state's mover.
        acting_player = state.current_player()
        return self._policies[acting_player].action_probabilities(state, player_id)
